library(nloptr)
library(tidyverse)
library(ggplot2)
library(PSweight)
library(nleqslv)
library(R.utils)
library(ggpattern)

# Seed the random number generator used by the simulation functions below.
set.seed(19248)

# Subgroup-level outcome means for the six subgroups. Sim() hands true_p0 to
# GenY() as the control-arm means (mu.c) and true_p1 as the treated-arm means
# (mu.t).
true_p0 <- c(3,2,4,9,12,2)
true_p1 <- c(5,1,1,10,9,20)

# Outcome standard deviations matching true_p0 (control) and true_p1 (treated);
# passed to GenY() as sd.c and sd.t respectively.
true_p0_sd <- c(1,1,2,3,4,1)
true_p1_sd <- c(2,1,1,5,2,6)

## true proportion
# Sampling probability of each subgroup (uniform over the six); read by GenX().
true_p <- rep(1/6,6)

# NOTE(review): the two expressions below are evaluated but never assigned, so
# they have no effect on later code (each reduces algebraically to 1 - mean).
# They look like leftovers from a binary-outcome version -- confirm and remove.
true_p1*(1-true_p1)/true_p1
true_p0*(1-true_p0)/true_p0

# Per-subgroup effect as a difference in means, and its true_p-weighted average.
# NOTE(review): Sim() estimates subgroup effects as log(mean1) - log(mean0) and
# checks interval coverage against tau_true_all, which is on the difference
# scale -- confirm the intended estimand.
tau_true <-  true_p1-true_p0
tau_true_all <- sum(tau_true*true_p)

# Draw subgroup labels for `n` subjects.
#
# n: number of labels to draw.
# Returns an integer vector of length `n` with values in 1..6, sampled with
# replacement using the global subgroup probabilities `true_p`.
GenX <- function(n) {
  n_groups <- 6L
  sample.int(n_groups, size = n, replace = TRUE, prob = true_p)
}

# Turn subgroup labels into a one-column-per-subgroup indicator matrix.
#
# X: vector of subgroup labels (values 1..6).
# Returns a logical matrix with length(X) rows and six columns named
# "A".."F"; entry [i, k] is TRUE exactly when X[i] equals k.
GenS <- function(X) {
  # Compare every label against every subgroup code in a single call.
  membership <- outer(X, 1:6, "==")
  colnames(membership) <- c("A", "B", "C", "D", "E", "F")
  membership
}



# Simulate normally distributed outcomes for the first `n` subjects.
#
# n:          number of outcomes to generate.
# X:          subgroup labels; accepted for interface compatibility but not
#             used (subgroup membership is read from `S`).
# Tr:         treatment indicators, 1 = treated, 0 = control.
# S:          indicator matrix with one row per subject and six columns; a
#             subject belongs to the first of columns 1..5 that is TRUE, and
#             to subgroup 6 otherwise.
# mu.t, mu.c: length-6 outcome means for the treated / control arm.
# sd.t, sd.c: length-6 outcome standard deviations for the treated / control
#             arm.
#
# Returns a numeric vector of length `n` (empty when n < 1). Stops with an
# error when `Tr` holds anything other than 0/1 or `S` holds NA among the
# first `n` subjects.
GenY <- function(n, X, Tr, S, mu.t, mu.c, sd.t, sd.c) {
  if (n < 1) {
    return(numeric(0))
  }

  rows <- seq_len(n)
  arm <- Tr[rows]
  if (anyNA(arm) || !all(arm %in% c(0, 1))) {
    stop("`Tr` must contain only 0/1 for the first `n` subjects", call. = FALSE)
  }

  member <- as.matrix(S)[rows, , drop = FALSE]
  storage.mode(member) <- "logical"
  if (anyNA(member)) {
    stop("`S` must not contain NA for the first `n` subjects", call. = FALSE)
  }

  # Subgroup of each subject: default to 6, then overwrite from column 5 down
  # to column 1 so that the lowest-numbered TRUE column wins.
  grp <- rep(6L, n)
  for (k in 5:1) {
    grp[member[, k]] <- k
  }

  treated <- arm == 1
  # One vectorised call draws the normals in subject order, so the RNG stream
  # matches drawing them one subject at a time.
  rnorm(
    n,
    mean = ifelse(treated, mu.t[grp], mu.c[grp]),
    sd = ifelse(treated, sd.t[grp], sd.c[grp])
  )
}



# Solve for the subgroup treatment-allocation probabilities.
#
# Minimises sum_k p_vec[k] * (sigma1_vec[k]^2 / e_k + sigma0_vec[k]^2 / (1 - e_k))
# over e_1..e_K with lower_c <= e_k <= upper_c, subject to inequality
# constraints handed to nloptr as values that must be <= 0:
#   * envy:    e_i - e_j - envy_c and -envy_c - e_i + e_j for every pair i < j
#              (i.e. |e_i - e_j| <= envy_c);
#   * welfare: -delta_N - log(e_k / (1 - e_k)) * tau[k] for every k, with
#              delta_N = sqrt(log(N_t) / N_t).
#
# tau:        per-subgroup effect estimates (length K).
# sigma1_vec: per-subgroup outcome SDs in the treated arm (length K).
# sigma0_vec: per-subgroup outcome SDs in the control arm (length K).
# envy_c:     tolerance used in the envy constraints.
# lower_c, upper_c: box bounds applied to every allocation probability.
# p_vec:      subgroup weights; K is taken from length(p_vec).
# N_t:        sample size entering delta_N.
# mu1_vec, mu0_vec: accepted for interface compatibility; not used.
#
# Returns the solver's solution as a numeric vector of length K named with the
# first K upper-case letters ("A".."F" for six subgroups).
SubAlloc <- function(tau, sigma1_vec, sigma0_vec,
                     envy_c, lower_c, upper_c, p_vec, N_t, mu1_vec, mu0_vec) {
  K <- length(p_vec)
  stopifnot(
    K >= 1,
    length(tau) == K,
    length(sigma1_vec) == K,
    length(sigma0_vec) == K
  )

  delta_N <- sqrt(log(N_t) / N_t)

  # Index pairs (i, j), i < j, in the order (1,2), (1,3), ..., (K-1,K).
  pair_idx <- if (K >= 2) utils::combn(K, 2) else matrix(integer(0), nrow = 2)
  first <- pair_idx[1, ]
  second <- pair_idx[2, ]

  # objective function
  eval_f0 <- function(x) {
    terms <- p_vec * (sigma1_vec^2 / x + sigma0_vec^2 / (1 - x))
    # Left-to-right accumulation keeps the value bit-identical to adding the
    # per-subgroup terms one after another.
    Reduce(`+`, terms)
  }

  # constraint function
  eval_g0 <- function(x) {
    # rbind + as.vector interleaves the two one-sided constraints of each pair.
    envy_constr <- as.vector(rbind(
      x[first] - x[second] - envy_c,
      -envy_c - x[first] + x[second]
    ))
    welfare_constr <- -delta_N - log(x / (1 - x)) * tau
    c(envy_constr, welfare_constr)
  }

  # Start at 0.01, moved inside [lower_c, upper_c] when the bounds exclude it.
  start <- min(max(0.01, lower_c), upper_c)

  # Solve using NLOPT_LN_COBYLA without gradient information
  res1 <- nloptr(x0 = rep(start, K),
                 eval_f = eval_f0,
                 lb = rep(lower_c, K),
                 ub = rep(upper_c, K),
                 eval_g_ineq = eval_g0,
                 opts = list("algorithm" = "NLOPT_LN_COBYLA",
                             "xtol_rel" = 1.0e-8))

  # optimal e*
  e_star <- res1$solution[seq_len(K)]
  names(e_star) <- LETTERS[seq_len(K)]

  e_star
}


# Script-level envy tolerance. Sim() and SubAlloc() take `envy_c` as an
# argument, so this value only takes effect where a caller passes it in
# (no such call is visible in this part of the file).
envy_c <- 0.1

# Run one replicate of the two-stage trial simulation and compare three
# allocation rules on the same stream of subjects.
#
# Stage 1 enrols m * 6 subjects with subgroups from GenX() and treatment
# assigned by a fair coin. Stage 2 enrols `n` further subjects one at a time;
# each subject is assigned separately under
#   (1) the proposed design: treat when the subgroup's current treated share
#       is below the SubAlloc() target computed from the data so far;
#   (2) complete randomisation: a fair coin;
#   (3) DBCD: treat when the subgroup's current treated share is below
#       sd_t / (sd_t + sd_c) built from the scaled SDs of that arm's data.
# Outcomes come from GenY() using the global true_p1 / true_p0 means and
# true_p1_sd / true_p0_sd standard deviations.
#
# m:      stage-1 size per subgroup slot (stage 1 has m * 6 subjects); >= 1.
# n:      number of sequentially enrolled subjects; >= 1.
# envy_c: envy tolerance forwarded to SubAlloc().
#
# Returns a numeric vector of length 45, in this order:
#   1-18   subgroup effects log(mean1) - log(mean0) for proposed, CR, DBCD
#          (six each, named "A".."F");
#   19-36  realised treated share per subgroup for proposed, CR, DBCD;
#   37-39  subgroup-share-weighted effect for proposed, CR, DBCD;
#   40-42  1/0 flag: the 95% interval contains the global tau_true_all;
#   43-45  1/0 flag: the upper 95% limit is below zero.
#
# Changes relative to the previous version:
#   * stage-1 subgroup shares are colSums(S) / (m * 6); the old expression
#     `colSums(S_1)/m*6` divided by m and then multiplied by 6;
#   * the control term of the CR variance now squares the SD, like the treated
#     term and like the proposed-design variance;
#   * n < 1 or m < 1 is rejected instead of looping over 1:0;
#   * data vectors are preallocated and the repeated per-subgroup summary and
#     interval code is shared through local helpers;
#   * control means are computed as sum(.) / n0 everywhere (previously
#     sum(. / n0) inside the loop; algebraically the same).
#
# NOTE(review): the CR and DBCD arms keep the original SD estimator
# sd(Y * Tr) / sd(Y * (1 - Tr)), which includes the zeros contributed by the
# other arm, whereas the proposed design uses the within-arm SD. Confirm
# whether the comparators should use the within-arm SD as well.
# NOTE(review): coverage is checked against tau_true_all, a difference in
# means, while the estimate is a weighted log ratio -- confirm the estimand.
Sim <- function(m, n, envy_c) {
  stopifnot(m >= 1, n >= 1)

  groups <- c("A", "B", "C", "D", "E", "F")
  K <- length(groups)
  n1 <- m * K
  N <- n1 + n

  # Per-subgroup summaries of one design's data; one row per subgroup.
  #   mu1, mu0            arm means
  #   tau                 log(mu1) - log(mu0)
  #   sd1, sd0            within-arm outcome SDs
  #   sd1_prod, sd0_prod  SDs of Y * Tr and Y * (1 - Tr) over the subgroup
  #   e1, e0              treated / control share of the subgroup
  group_stats <- function(Y, Tr, X) {
    rows <- lapply(seq_len(K), function(j) {
      in_j <- X == j
      y <- Y[in_j]
      tr <- Tr[in_j]
      mu1 <- sum(y * tr) / sum(tr)
      mu0 <- sum(y * (1 - tr)) / sum(1 - tr)
      c(mu1 = mu1, mu0 = mu0, tau = log(mu1) - log(mu0),
        sd1 = sd(y[tr == 1]), sd0 = sd(y[tr == 0]),
        sd1_prod = sd(y * tr), sd0_prod = sd(y * (1 - tr)),
        e1 = sum(tr) / length(tr), e0 = sum(1 - tr) / length(tr))
    })
    out <- do.call(rbind, rows)
    rownames(out) <- groups
    out
  }

  # Weighted effect, 95% interval, coverage and rejection flag for one design.
  # `within_var` is the design-specific within-subgroup variance term; the
  # between-subgroup term is added here.
  design_summary <- function(stats, within_var, p) {
    mu1 <- stats[, "mu1"]
    mu0 <- stats[, "mu0"]
    ate <- sum(stats[, "tau"] * p)
    ate_1 <- sum(mu1 * p)
    ate_0 <- sum(mu0 * p)
    between_var <- sum(p * (mu1 - ate_1)^2 / mu1 + p * (mu0 - ate_0)^2 / mu0)
    sd_total <- sqrt(within_var + between_var)
    hi <- ate + 1.96 * sd_total / sqrt(N)
    lo <- ate - 1.96 * sd_total / sqrt(N)
    list(ate = ate,
         cover = (lo < tau_true_all) & (hi > tau_true_all),
         power = (hi < 0))
  }

  ## Stage 1: random assignment with e = 1/2 ------------------------------
  X_1 <- GenX(n1)
  S_1 <- GenS(X_1)
  T_1 <- rbinom(n1, 1, 0.5)
  Y_1 <- GenY(n1, X_1, T_1, S_1, true_p1, true_p0, true_p1_sd, true_p0_sd)

  # Stage-1 subgroup shares; these stay fixed during the sequential stage.
  p_old <- colSums(S_1) / n1

  st_opt <- group_stats(Y_1, T_1, X_1)

  # Scaled SDs that drive the DBCD target for the first sequential subject.
  sd_db.t <- sqrt(p_old * st_opt[, "sd1"]^2 / st_opt[, "e1"])
  sd_db.c <- sqrt(p_old * st_opt[, "sd0"]^2 / st_opt[, "e0"])

  # Preallocate the full trial; positions beyond the current subject stay NA
  # and are never read.
  X_all <- c(X_1, rep(NA_integer_, n))
  T_opt <- T_cr <- T_db <- c(T_1, rep(NA_real_, n))
  Y_opt <- Y_cr <- Y_db <- c(Y_1, rep(NA_real_, n))
  st_cr <- st_db <- NULL

  ## Stage 2: fully sequential assignments ---------------------------------
  for (i in seq_len(n)) {
    n_old <- n1 + i - 1
    seen <- seq_len(n_old)
    pos <- n_old + 1

    X_i <- GenX(1)
    S_i <- GenS(X_i) # 1 x 6 indicator row for subject i
    in_g <- X_all[seen] == X_i # earlier subjects in the same subgroup

    ## (1) Proposed design
    e_opt <- SubAlloc(st_opt[, "tau"], st_opt[, "sd1"], st_opt[, "sd0"],
                      envy_c, lower_c = 0.01, upper_c = 0.99,
                      p_old, n_old, st_opt[, "mu1"], st_opt[, "mu0"])
    e_current_i <- sum(T_opt[seen][in_g]) / sum(in_g)
    T_i <- as.numeric(e_current_i < e_opt[[X_i]])
    Y_i <- GenY(1, X_i, T_i, S_i, true_p1, true_p0, true_p1_sd, true_p0_sd)

    ## (2) Complete randomization
    T_i.cr <- rbinom(1, 1, 1 / 2)
    Y_i.cr <- GenY(1, X_i, T_i.cr, S_i,
                   true_p1, true_p0, true_p1_sd, true_p0_sd)

    ## (3) DBCD
    e_db <- sd_db.t / (sd_db.t + sd_db.c)
    e_db_current_i <- sum(T_db[seen][in_g]) / sum(in_g)
    T_i.db <- as.numeric(e_db_current_i < e_db[[X_i]])
    Y_i.db <- GenY(1, X_i, T_i.db, S_i,
                   true_p1, true_p0, true_p1_sd, true_p0_sd)

    # Record subject i under every design.
    X_all[pos] <- X_i
    T_opt[pos] <- T_i
    Y_opt[pos] <- Y_i
    T_cr[pos] <- T_i.cr
    Y_cr[pos] <- Y_i.cr
    T_db[pos] <- T_i.db
    Y_db[pos] <- Y_i.db

    # Refresh the per-subgroup summaries with subject i included.
    now <- seq_len(pos)
    st_opt <- group_stats(Y_opt[now], T_opt[now], X_all[now])
    st_cr <- group_stats(Y_cr[now], T_cr[now], X_all[now])
    st_db <- group_stats(Y_db[now], T_db[now], X_all[now])

    sd_db.t <- sqrt(p_old * st_db[, "sd1_prod"]^2 / st_db[, "e1"] /
                      st_db[, "mu1"])
    sd_db.c <- sqrt(p_old * st_db[, "sd0_prod"]^2 / st_db[, "e0"] /
                      st_db[, "mu0"])
  }

  # Final subgroup shares over all m * 6 + n subjects.
  p_old <- tabulate(X_all, nbins = K) / N
  names(p_old) <- groups

  # Realised treated share per subgroup (unnamed, as before).
  sub_alloc_opt <- unname(st_opt[, "e1"])
  sub_alloc_cr <- unname(st_cr[, "e1"])
  sub_alloc_db <- unname(st_db[, "e1"])

  ## Proposed design --------------------------------------------------------
  nu_opt <- 1 / p_old * (st_opt[, "sd1"]^2 / sub_alloc_opt / st_opt[, "mu1"] +
                           st_opt[, "sd0"]^2 / (1 - sub_alloc_opt) /
                           st_opt[, "mu0"])
  res_opt <- design_summary(st_opt, sum(p_old^2 * nu_opt), p_old)

  ## Complete randomization -------------------------------------------------
  nu_cr <- 1 / p_old * (st_cr[, "sd1_prod"]^2 / sub_alloc_cr) /
    st_cr[, "mu1"] +
    1 / p_old * (st_cr[, "sd0_prod"]^2 / (1 - sub_alloc_cr)) /
    st_cr[, "mu0"]
  res_cr <- design_summary(st_cr, sum(p_old^2 * nu_cr), p_old)

  ## DBCD ---------------------------------------------------------------------
  # sd_db.t / sd_db.c already carry the subgroup share and allocation scaling.
  nu_db <- sd_db.t^2 + sd_db.c^2
  res_db <- design_summary(st_db, sum(nu_db), p_old)

  c(st_opt[, "tau"], st_cr[, "tau"], st_db[, "tau"], # 18
    sub_alloc_opt, sub_alloc_cr, sub_alloc_db, # 18
    res_opt$ate, res_cr$ate, res_db$ate, # 3
    res_opt$cover, res_cr$cover, res_db$cover, # 3
    res_opt$power, res_cr$power, res_db$power) # 3
}

